#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import argparse
import multiprocessing
import os
import subprocess
import time
from argparse import Namespace as Args
from typing import Dict, List, Optional
from gqlalchemy import Memgraph
import yaml
from workers import get_worker_object, get_worker_steps


# Filesystem layout. Every path is derived from the directory that holds this
# script; BASE_DIR is two directory levels above it.
SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__))
BASE_DIR = os.path.normpath(os.path.join(SCRIPT_DIR, os.pardir, os.pardir))
BUILD_DIR = os.path.join(BASE_DIR, "build")
MEASUREMENTS_FILE = os.path.join(SCRIPT_DIR, ".apollo_measurements")
KEY_FILE = os.path.join(SCRIPT_DIR, ".key.pem")
CERT_FILE = os.path.join(SCRIPT_DIR, ".cert.pem")
DEPLOYMENTS_DIR = os.path.join(SCRIPT_DIR, "configurations/deployments")

# Long running stats file
STATS_FILE = os.path.join(SCRIPT_DIR, ".long_running_stats")

# Worker count: the THREADS environment variable wins when set (the value is
# kept as the raw string), otherwise the machine's CPU count (an int) is used.
THREADS = os.environ.get("THREADS", multiprocessing.cpu_count())


class Config:
    """Read-only view over the parsed stress-test YAML configuration.

    Missing sections and sections that are present but empty (which YAML
    parses to ``None``) are both treated as empty, so every property always
    returns a usable value instead of raising ``AttributeError``.

    Args:
        config: The mapping produced by parsing the YAML configuration file,
            or ``None`` for an empty document.
    """

    def __init__(self, config):
        config = config or {}

        general = self._section(config, "general")
        self._verbose = general.get("verbose", False)
        self._use_ssl = general.get("use_ssl", False)

        # Memgraph configuration
        memgraph = self._section(config, "memgraph")
        self._script = self._section(memgraph, "deployment").get("script", None)
        self._memgraph_args = memgraph.get("args") or []

        # Dataset tests
        self._tests = self._section(config, "dataset").get("tests") or []

        # Custom workloads (note the camelCase key used by the YAML schema)
        self._custom_workloads = self._section(config, "customWorkloads").get("tests") or []

    @staticmethod
    def _section(mapping, key):
        """Return ``mapping[key]``, or an empty dict if it is missing or empty."""
        return mapping.get(key) or {}

    @property
    def uses_ssl(self):
        """Value of ``general.use_ssl`` (``False`` when unset)."""
        return self._use_ssl

    @property
    def is_verbose(self):
        """Value of ``general.verbose`` (``False`` when unset)."""
        return self._verbose

    @property
    def script(self):
        """Value of ``memgraph.deployment.script`` (``None`` when unset)."""
        return self._script

    @property
    def tests(self):
        """List under ``dataset.tests`` (empty when unset)."""
        return self._tests

    @property
    def custom_workloads(self):
        """List under ``customWorkloads.tests`` (empty when unset)."""
        return self._custom_workloads

    @property
    def memgraph_args(self):
        """List under ``memgraph.args`` (empty when unset)."""
        return self._memgraph_args


def start_deployment(script_path: str, args: List[str]) -> None:
    """Invoke the deployment script's ``start`` action and wait briefly.

    Args:
        script_path: Path of the deployment script to execute.
        args: Extra arguments appended after ``start``.

    Raises:
        subprocess.CalledProcessError: If the script exits with a non-zero code.
    """
    command = [script_path, "start", *args]
    joined = " ".join(command)
    print(f"Starting Memgraph with: {joined}")
    subprocess.run(command, check=True)
    # Fixed grace period so the server has time to come up before it is used.
    time.sleep(2)


def stop_deployment(script_path: str) -> None:
    """Invoke the deployment script's ``stop`` action.

    Args:
        script_path: Path of the deployment script to execute.

    Raises:
        subprocess.CalledProcessError: If the script exits with a non-zero code.
    """
    command = [script_path, "stop"]
    joined = " ".join(command)
    print(f"Stopping Memgraph with: {joined}")
    subprocess.run(command, check=True)


def check_deployment_status(script_path: str) -> bool:
    """Report whether the deployment script's ``status`` action succeeds.

    The script's stdout and stderr are captured and discarded.

    Args:
        script_path: Path of the deployment script to execute.

    Returns:
        ``True`` if the script exited with code 0, ``False`` otherwise.
    """
    completed = subprocess.run([script_path, "status"], capture_output=True)
    return completed.returncode == 0


def parse_arguments() -> Args:
    """Parse the command line.

    Returns:
        A namespace with ``config_file`` (path of the YAML configuration) and
        ``python`` (interpreter used to launch ``.py`` tests).
    """
    default_config = os.path.join(SCRIPT_DIR, "configurations/templates/config_small.yaml")
    default_python = os.path.join(BASE_DIR, "tests", "ve3", "bin", "python3")

    cli = argparse.ArgumentParser(description="Run stress tests on Memgraph.")
    cli.add_argument("--config-file", default=default_config)
    cli.add_argument("--python", default=default_python, type=str)
    return cli.parse_args()


def run_test(args: Args, config: Config, test: str, options: List[str], timeout: int) -> float:
    """Run a single stress test and return how long it took.

    Args:
        args: Parsed command-line arguments (forwarded to the binary lookup).
        config: Loaded configuration (forwarded to the binary lookup).
        test: Test file name; its extension selects how it is launched.
        options: Extra arguments appended to the test command line.
        timeout: Time limit in minutes for the test process.

    Returns:
        Elapsed wall-clock time in seconds.

    Raises:
        Exception: If the test process exits with a non-zero code.
        subprocess.TimeoutExpired: If the process outlives ``timeout`` minutes.
    """
    print(f"Running test '{test}'")

    command = [*_find_test_binary(args, config, test), "--worker-count", str(THREADS), *options]

    started_at = time.time()
    completed = subprocess.run(command, cwd=SCRIPT_DIR, timeout=timeout * 60)
    if completed.returncode:
        raise Exception(f"Test '{test}' binary returned non-zero ({completed.returncode})!")

    elapsed = time.time() - started_at
    print(f"    Done after {elapsed:.3f} seconds")
    return elapsed


def _find_test_binary(args: Args, config: Config, test: str) -> List[str]:
    """Build the command prefix used to launch *test*.

    ``.cpp`` tests map to the executable of the same base name under the
    build tree; ``.py`` tests are run unbuffered with ``args.python`` and get
    a ``--logging`` level derived from the verbosity setting.

    Raises:
        Exception: If *test* has neither a ``.py`` nor a ``.cpp`` extension.
    """
    if test.endswith(".cpp"):
        executable = os.path.join(BUILD_DIR, "tests", "stress", test[: -len(".cpp")])
        return [executable]

    if not test.endswith(".py"):
        raise Exception(f"Test '{test}' binary not supported!")

    log_level = "DEBUG" if config.is_verbose else "WARNING"
    script = os.path.join(SCRIPT_DIR, test)
    return [args.python, "-u", script, "--logging", log_level]


def _terminate_alive_workers(processes) -> None:
    """Terminate and reap every process in *processes* that is still running."""
    for process in processes:
        if process.is_alive():
            process.terminate()
            process.join()


def _join_workers_until_deadline(processes, timeout_sec) -> None:
    """Wait for *processes*, terminating any still alive after *timeout_sec*.

    A single deadline is shared by all processes. Joining each one with the
    full timeout in turn would let the N-th process run for up to N times the
    configured limit.
    """
    deadline = time.monotonic() + timeout_sec
    for process in processes:
        process.join(max(0.0, deadline - time.monotonic()))
        if process.is_alive():
            print(f"Worker process exceeded timeout of {timeout_sec // 60} minutes. Terminating...")
            process.terminate()
            process.join()


def run_custom_workloads(script_path: str, config: Config) -> Optional[Dict[str, float]]:
    """Runs custom workloads on Memgraph using GQLAlchemy.

    For every workload the deployment is started, the import queries are
    executed, and the workers are run step by step: all workers (and their
    replicas) belonging to a step are started as separate processes, and the
    next step begins only once they have finished or been terminated after
    the workload's ``timeout_min`` (default 10 minutes). The deployment is
    stopped after each workload, whether it succeeded or not.

    Args:
        script_path: Path of the deployment script used to start/stop Memgraph.
        config: Loaded configuration providing the workloads and base
            Memgraph arguments.

    Returns:
        Mapping of workload name to elapsed wall-clock seconds (deployment
        start and import included), or ``None`` as soon as one workload raises.
    """
    print("Running custom workloads...")
    runtimes = {}

    for workload in config.custom_workloads:
        workload_name = workload["name"]
        print(f"Running workload '{workload_name}'")
        start = time.time()

        # Processes of the step currently in flight. Kept at this scope so the
        # cleanup below can reap them if anything fails part-way through.
        step_processes = []
        failed = False

        try:
            # Start Memgraph deployment
            start_deployment(script_path, config.memgraph_args + workload.get("memgraph_args", []))

            # Run import queries
            print(f"Started import for workload '{workload_name}'")

            querying = workload.get("querying", {})
            query_host = querying.get("host", "127.0.0.1")
            query_port = querying.get("port", 7687)

            memgraph = Memgraph(host=query_host, port=query_port)
            for query in workload.get("import", {}).get("queries", []):
                print(f"Executing query: {query}")
                memgraph.execute(query)
            print(f"Finished import for workload '{workload_name}'")

            # Execute workers in steps
            workers = workload.get("workers", [])
            steps = get_worker_steps(workers)
            timeout_sec = workload.get("timeout_min", 10) * 60

            for step in steps:
                print(f"Running step {step}.")
                step_processes = []

                for worker in [x for x in workers if x.get("step", 1) == step]:
                    # Override with workload global querying if not set on the worker
                    worker.setdefault("querying", {"host": query_host, "port": query_port})
                    worker_object = get_worker_object(worker)

                    for _ in range(worker.get("replicas", 1)):
                        process = multiprocessing.Process(target=worker_object.run)
                        process.start()
                        # Track only processes that actually started.
                        step_processes.append(process)

                # Monitor processes and enforce the shared timeout
                _join_workers_until_deadline(step_processes, timeout_sec)

            print(f"Finished workload '{workload_name}'")
            runtimes[workload_name] = time.time() - start

        except Exception as ex:
            print(f"Failed to execute workload {workload_name} with exception:", ex)
            failed = True
        finally:
            try:
                # No-op on success; after a failure this prevents worker
                # processes from outliving the deployment they talk to.
                _terminate_alive_workers(step_processes)
            finally:
                stop_deployment(script_path)

        if failed:
            return None

    return runtimes


def run_stress_test_suite(args: Args, config: Config) -> Optional[Dict[str, float]]:
    """Run all dataset tests, then all custom workloads.

    Each dataset test gets its own deployment: Memgraph is started with the
    base arguments plus the test's ``memgraph_args``, the test is run, and
    the deployment is stopped again regardless of the outcome.

    Args:
        args: Parsed command-line arguments.
        config: Loaded configuration.

    Returns:
        Mapping of test name (extension stripped) and workload name to runtime
        in seconds, or ``None`` if starting Memgraph, a test, or a workload
        failed.

    Raises:
        Exception: If no deployment script is configured or it does not exist
            under the deployments directory.
    """
    if not config.script:
        raise Exception("Memgraph deployment script not found!")
    deployment_script = os.path.join(DEPLOYMENTS_DIR, config.script)
    if not os.path.exists(deployment_script):
        raise Exception(f"Memgraph deployment script not found: {deployment_script}!")

    dataset_runtimes = {}
    for test in config.tests:
        test_name, test_flags, timeout = test["name"], test["test_args"], test["timeout_min"]
        memgraph_flags = config.memgraph_args + test.get("memgraph_args", [])

        # A deployment that failed to start is not stopped; just give up.
        try:
            start_deployment(deployment_script, memgraph_flags)
        except Exception as ex:
            print("Exception occurred while starting Memgraph:", ex)
            return None

        try:
            elapsed = run_test(args, config, test_name, test_flags, timeout)
            dataset_runtimes[os.path.splitext(test_name)[0]] = elapsed
        except Exception as ex:
            print(f"Failed to execute {test_name} with exception:", ex)
            # The `finally` clause below still stops the deployment first.
            return None
        finally:
            stop_deployment(deployment_script)

    workload_runtimes = run_custom_workloads(deployment_script, config)
    return None if workload_runtimes is None else dataset_runtimes | workload_runtimes


def write_stats(runtimes: Dict[str, float]) -> None:
    """Write one ``<name>.runtime <seconds>`` line per entry to the measurements file.

    The file is overwritten; lines are newline-separated with no trailing
    newline.

    Args:
        runtimes: Mapping of test/workload name to runtime in seconds.
    """
    lines = [f"{name}.runtime {seconds}" for name, seconds in runtimes.items()]
    with open(MEASUREMENTS_FILE, "w") as out:
        out.write("\n".join(lines))


if __name__ == "__main__":
    # Entry point: load the configuration, run the whole suite, and either
    # record the measured runtimes or exit with status 1 on failure.
    args = parse_arguments()

    # Use a context manager so the config file handle is closed right away.
    with open(args.config_file) as config_stream:
        # An empty YAML document parses to None; treat it as an empty mapping.
        config = Config(yaml.safe_load(config_stream) or {})

    runtimes = run_stress_test_suite(args, config)

    if runtimes is None:
        print("Some stress tests have failed")
        # `exit()` is injected by the `site` module and is not guaranteed to
        # exist; raising SystemExit is the equivalent that always works.
        raise SystemExit(1)

    write_stats(runtimes)
    print("Successfully ran stress tests!")
